{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9dcd3717",
   "metadata": {},
   "source": [
    "# General Example\n",
    "In this example we show how to load a benchmark dataset, how to train a model on it and which are the different types of evaluation protocols that we can use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "279954ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
    "import tensorflow as tf\n",
    "tf.get_logger().setLevel('ERROR')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92feb589",
   "metadata": {},
   "source": [
    "## Load the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "65e97116",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ampligraph\n",
    "# Benchmark datasets are under ampligraph.datasets module\n",
    "from ampligraph.datasets import load_fb15k_237\n",
    "# load fb15k-237 dataset\n",
    "dataset = load_fb15k_237()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f9f3ecb",
   "metadata": {},
   "source": [
    "## Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1c81d1d8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "29/29 [==============================] - 7s 233ms/step - loss: 36773.3906\n",
      "Epoch 2/10\n",
      "29/29 [==============================] - 6s 215ms/step - loss: 22626.3574\n",
      "Epoch 3/10\n",
      "29/29 [==============================] - 6s 202ms/step - loss: 17343.5254\n",
      "Epoch 4/10\n",
      "29/29 [==============================] - 6s 198ms/step - loss: 14640.1602\n",
      "Epoch 5/10\n",
      "29/29 [==============================] - 6s 223ms/step - loss: 13013.9121\n",
      "Epoch 6/10\n",
      "29/29 [==============================] - 6s 202ms/step - loss: 11937.1406\n",
      "Epoch 7/10\n",
      "29/29 [==============================] - 6s 194ms/step - loss: 11176.6133\n",
      "Epoch 8/10\n",
      "29/29 [==============================] - 6s 213ms/step - loss: 10612.2920\n",
      "Epoch 9/10\n",
      "29/29 [==============================] - 6s 196ms/step - loss: 10177.2930\n",
      "Epoch 10/10\n",
      "29/29 [==============================] - 6s 200ms/step - loss: 9833.4297\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x3836fd1b0>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Import the KGE model\n",
    "from ampligraph.latent_features import ScoringBasedEmbeddingModel\n",
    "\n",
    "# you can continue training from where you left after restoring the model\n",
    "tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./transe_train_logs')\n",
    "\n",
    "# create the model with transe scoring function\n",
    "model = ScoringBasedEmbeddingModel(eta=5,\n",
    "                                   k=300,\n",
    "                                   scoring_type='TransE')\n",
    "\n",
    "# you can either use optimizers/regularizers/loss/initializers with default values or you can \n",
    "# import it and customize the hyperparameters and pass it to compile\n",
    "\n",
    "# Let's create an adam optimizer with customized learning rate =0.005\n",
    "adam = tf.keras.optimizers.Adam(learning_rate=0.005)\n",
    "# Let's compile the model with self_advarsarial loss of default parameters\n",
    "model.compile(optimizer=adam, loss='self_adversarial')\n",
    "\n",
    "# fit the model to data.\n",
    "model.fit(dataset['train'],\n",
    "             batch_size=10000,\n",
    "             epochs=10,\n",
    "             callbacks=[tensorboard_callback])\n",
    "\n",
    "# the training can be visualised using the following command:\n",
    "# tensorboard --logdir='./transe_train_logs' --port=8891 \n",
    "# open the browser and go to the following URL: http://127.0.0.1:8891/"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "513f870b",
   "metadata": {},
   "source": [
    "## Predict scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "df6d51e8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-2.266111 , -2.095364 , -2.289988 , -3.8879156, -4.619501 ],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred = model.predict(dataset['test'][:5], \n",
    "                       batch_size=100)\n",
    "pred"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f74dbc21",
   "metadata": {},
   "source": [
    "## Evaluate the model (without filter)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc312d21",
   "metadata": {},
   "source": [
    "### Subject and object (s,o) corruption\n",
    "This evaluation protocol consists in creating corrupted triples via the corruption of both subject and object of existing triples. This is considered the standard protocol for evaluation in Knowledge Graph Embedding models literature."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a6c3b19f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "206/206 [==============================] - 46s 223ms/step\n",
      "MR: 472.3101575496624\n",
      "MRR: 0.08900155667769008\n",
      "hits@1: 0.0\n",
      "hits@10: 0.2428319796457579\n"
     ]
    }
   ],
   "source": [
    "# evaluate on the test set\n",
    "ranks = model.evaluate(dataset['test'],     # test set\n",
    "                       batch_size=100,      # evaluation batch size\n",
    "                       corrupt_side='s,o'   # sides to corrupt for scoring and ranking\n",
    "                       )\n",
    "\n",
    "# import the evaluation metrics\n",
    "from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score\n",
    "\n",
    "print('MR:', mr_score(ranks))\n",
    "print('MRR:', mrr_score(ranks))\n",
    "print('hits@1:', hits_at_n_score(ranks, 1))\n",
    "print('hits@10:', hits_at_n_score(ranks, 10))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4dc1efe8",
   "metadata": {},
   "source": [
    "### Object corruption\n",
    "Corruptions are generated by changing only the object of triples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "66cb72a5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "206/206 [==============================] - 25s 123ms/step\n",
      "MR: 268.11287797240436\n",
      "MRR: 0.12974471587250763\n",
      "hits@1: 0.0\n",
      "hits@10: 0.3512085331245719\n"
     ]
    }
   ],
   "source": [
    "# evaluate on the test set\n",
    "ranks = model.evaluate(dataset['test'], \n",
    "                       batch_size=100, \n",
    "                       corrupt_side='o' # corrupt only the object\n",
    "                       )\n",
    "\n",
    "# import the evaluation metrics\n",
    "from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score\n",
    "\n",
    "print('MR:', mr_score(ranks))\n",
    "print('MRR:', mrr_score(ranks))\n",
    "print('hits@1:', hits_at_n_score(ranks, 1))\n",
    "print('hits@10:', hits_at_n_score(ranks, 10))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "964063c6",
   "metadata": {},
   "source": [
    "### Subject corruption\n",
    "Corruptions are generated by changing only the subject of triples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4911c2c2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "206/206 [==============================] - 27s 129ms/step\n",
      "MR: 676.5074371269204\n",
      "MRR: 0.048258397482872535\n",
      "hits@1: 0.0\n",
      "hits@10: 0.13445542616694392\n"
     ]
    }
   ],
   "source": [
    "# evaluate on the test set\n",
    "ranks = model.evaluate(dataset['test'], \n",
    "                       batch_size=100, \n",
    "                       corrupt_side='s' # corrupt only the subject\n",
    "                       )\n",
    "\n",
    "# import the evaluation metrics\n",
    "from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score\n",
    "\n",
    "print('MR:', mr_score(ranks))\n",
    "print('MRR:', mrr_score(ranks))\n",
    "print('hits@1:', hits_at_n_score(ranks, 1))\n",
    "print('hits@10:', hits_at_n_score(ranks, 10))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5b60bad",
   "metadata": {},
   "source": [
    "## Evaluation with Filters\n",
    "Triples specified inside the filter are removed from the corruptions that are generated to avoid the creation of false negatives."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e6004f04",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "206/206 [==============================] - 929s 5s/step\n",
      "MR: 364.22042274195127\n",
      "MRR: 0.19061439893688528\n",
      "hits@1: 0.12329973578628045\n",
      "hits@10: 0.32199823857520304\n"
     ]
    }
   ],
   "source": [
    "# evaluate on the test set\n",
    "ranks = model.evaluate(dataset['test'], \n",
    "                       batch_size=100, \n",
    "                       corrupt_side='s,o', # corrupt both subject and object\n",
    "                       use_filter={'train':dataset['train'], # Filter to be used for evaluation\n",
    "                                   'valid':dataset['valid'],\n",
    "                                   'test':dataset['test']}\n",
    "                       )\n",
    "\n",
    "# import the evaluation metrics\n",
    "from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score\n",
    "\n",
    "print('MR:', mr_score(ranks))\n",
    "print('MRR:', mrr_score(ranks))\n",
    "print('hits@1:', hits_at_n_score(ranks, 1))\n",
    "print('hits@10:', hits_at_n_score(ranks, 10))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0620885c",
   "metadata": {},
   "source": [
    "## Evaluation using a subset of entities for corruption\n",
    "Specify a subset of entities that are used for corrupting, depending on the evaluation strategy chosen, either the subject, the object or both. Notice that, despite not being the standard evaluation protocol, using a subset of entities can generate more meaningful corruptions and also reduce a lot the computational overhead caused by sampling corruptions among all the entities in the knowledge graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "864b6d01",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "12"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Let's get all the month present in training\n",
    "months = set(dataset['train'][\n",
    "    dataset['train'][:, 1] == \n",
    "        '/travel/travel_destination/climate./travel/travel_destination_monthly_climate/month'][:, 2])\n",
    "len(months)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f910cab7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# consider we are evaluating the below test set which is specific to one predicate\n",
    "# This predicate tells the best time of the year(o) to visit a destination (s)\n",
    "dest_month_test_triples = dataset['test'][\n",
    "    dataset['test'][:, 1] ==\n",
    "        '/travel/travel_destination/climate./travel/travel_destination_monthly_climate/month']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "641446d7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2/2 [==============================] - 2s 1s/step\n",
      "MR: 1.0833333333333333\n",
      "MRR: 0.9861111111111112\n",
      "hits@1: 0.9833333333333333\n",
      "hits@10: 1.0\n"
     ]
    }
   ],
   "source": [
    "# Let's say we want to evaluate this test set by corrupting the object with only months.\n",
    "# we can pass the months as entities_subset and generate corruptions only using this subset \n",
    "# instead of all entities in the graph\n",
    "# This approach is very useful when the graph size is big and/or \n",
    "# when our hypothesis belongs to a specific predicate type\n",
    "# When graph size is big we can randomly sample fixed number of small subset of entities and use it as corruption\n",
    "\n",
    "# evaluate on the test set\n",
    "ranks = model.evaluate(dest_month_test_triples, \n",
    "                       batch_size=100, \n",
    "                       corrupt_side='o', # corrupt only the object\n",
    "                       entities_subset=months,\n",
    "                       use_filter={'train':dataset['train'], # Filter to be used for evaluation\n",
    "                                   'valid':dataset['valid'],\n",
    "                                   'test':dataset['test']}\n",
    "                       )\n",
    "\n",
    "# import the evaluation metrics\n",
    "from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score\n",
    "\n",
    "print('MR:', mr_score(ranks))\n",
    "print('MRR:', mrr_score(ranks))\n",
    "print('hits@1:', hits_at_n_score(ranks, 1))\n",
    "print('hits@10:', hits_at_n_score(ranks, 10))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01297daf",
   "metadata": {},
   "source": [
    "## Visualize the embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "2ab4420b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ampligraph.utils import create_tensorboard_visualizations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "825cf70f",
   "metadata": {},
   "outputs": [],
   "source": [
    "create_tensorboard_visualizations(model, \n",
    "                                  entities_subset=['/m/027rn', '/m/06cx9', '/m/017dcd', '/m/06v8s0', '/m/07s9rl0'], \n",
    "                                  labels=['ent1', 'ent2', 'ent3', 'ent4', 'ent5'],\n",
    "                                  loc = './selected_subset_embeddings_vis')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8807c88",
   "metadata": {},
   "outputs": [],
   "source": [
    "create_tensorboard_visualizations(model, \n",
    "                                  entities_subset='all',\n",
    "                                  loc = './full_embeddings_vis')\n",
    "\n",
    "# the embeddings can be visualised uncommenting the following command:\n",
    "# ! tensorboard --logdir='./full_embeddings_vis' --port=8891 \n",
    "# open the browser and go to the following URL: http://127.0.0.1:8891/#projector"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "2e69f3670cdad0193847aaa0b77be56c05c951fcbdd384ff882dde0464f4de76"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
